package leetcode

import (
	"math"

	"github.com/halfrost/LeetCode-Go/template"
)

// minMalwareSpread2 solves "Minimize Malware Spread II".
//
// graph is an adjacency matrix and initial lists the initially infected nodes.
// It returns the infected node whose complete removal (the node and all of its
// edges) leaves the fewest nodes infected once the malware has spread. Ties are
// broken by the smallest node index. An empty initial slice yields 0.
//
// Approach: union only the clean (non-initial) nodes into components. A clean
// component adjacent to exactly one infected node m stays clean if and only if
// m is removed, so the benefit of removing m is the total size of every such
// component that depends solely on m.
func minMalwareSpread2(graph [][]int, initial []int) int {
	if len(initial) == 0 {
		return 0
	}
	n := len(graph)
	infected := make(map[int]bool, len(initial))
	for _, m := range initial {
		infected[m] = true
	}

	// Connect clean nodes only; infected nodes remain singleton sets. The
	// matrix is assumed symmetric, so scanning the lower triangle (j < i)
	// covers every edge.
	uf := template.UnionFind{}
	uf.Init(n)
	for i := 0; i < n; i++ {
		if infected[i] {
			continue
		}
		for j := 0; j < i; j++ {
			if graph[i][j] == 1 && !infected[j] {
				uf.Union(i, j)
			}
		}
	}

	// Size of every clean component, keyed by its root.
	size := make(map[int]int)
	for i := 0; i < n; i++ {
		if !infected[i] {
			size[uf.Find(i)]++
		}
	}

	// For each clean component (by root), the set of distinct infected nodes
	// directly adjacent to it.
	touching := make(map[int]map[int]struct{})
	for _, m := range initial {
		for j := 0; j < n; j++ {
			if graph[m][j] != 1 || infected[j] {
				continue
			}
			root := uf.Find(j)
			if touching[root] == nil {
				touching[root] = make(map[int]struct{})
			}
			touching[root][m] = struct{}{}
		}
	}

	// Credit each infected node with every component it alone would infect.
	// Summing matters: one node can be the sole source for several components,
	// and comparing components one at a time picks the wrong node.
	saved := make(map[int]int, len(initial))
	for root, sources := range touching {
		if len(sources) != 1 {
			continue
		}
		for m := range sources {
			saved[m] += size[root]
		}
	}

	// Iterate over initial (not a map) so the choice is deterministic: take the
	// largest saving, breaking ties by the smallest node index.
	best, bestSaved := initial[0], math.MinInt64
	for _, m := range initial {
		if s := saved[m]; s > bestSaved || (s == bestSaved && m < best) {
			best, bestSaved = m, s
		}
	}
	return best
}

// min returns the smaller of a and b; when they are equal it returns a.
func min(a int, b int) int {
	smaller := a
	if b < a {
		smaller = b
	}
	return smaller
}
